package org.spider.core.executor.shape;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.ibatis.jdbc.SQL;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.spider.api.context.SpiderContext;
import org.spider.api.context.SpiderContextHolder;
import org.spider.api.domain.utilDomain.KeyValue;
import org.spider.api.domain.utilDomain.SpiderNode;
import org.spider.api.domain.utilDomain.SpiderOutput;
import org.spider.api.executor.ShapeExecutor;
import org.spider.api.listener.TaskListener;
import org.spider.api.utils.io.FileUtils;
import org.spider.core.utils.DataSourceUtils;
import org.spider.core.utils.ExpressionUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Executor for the {@code "output"} node shape.
 *
 * <p>Evaluates each configured output expression, stores the result in the task variables,
 * reports it to the context output, and optionally persists the row to a database table
 * and/or appends it to a CSV file. CSV printers are cached per (context, node) and are
 * flushed and closed in {@link #afterListener()}.
 */
@Component
public class OutputExecutor implements ShapeExecutor, TaskListener {
    public static final String OUTPUT = "output";
    public static final String DATASOURCE_ID = "datasourceId";
    public static final String OUTPUT_DATABASE = "outputDatabase";
    public static final String OUTPUT_CSV = "outputCsv";
    public static final String TABLE_NAME = "tableName";
    public static final String CSV_NAME = "csvName";
    public static final String CSV_ENCODING = "csvEncoding";
    private static final Logger logger = LoggerFactory.getLogger(OutputExecutor.class);

    /** Directory in which CSV files are created (Spring property {@code spider.fileDir}). */
    @Value("${spider.fileDir}")
    private String fileDir;

    /**
     * Open CSV printers, keyed by {@code contextId + "-" + nodeId}.
     *
     * <p>A {@link ConcurrentHashMap} is required because {@link #myOutputCsv} reads this map
     * outside its creation lock; reading a plain {@code HashMap} concurrently with a write is
     * a data race.
     */
    private final Map<String, CSVPrinter> cachePrinter = new ConcurrentHashMap<>();

    /**
     * Evaluates every output expression of {@code node} and publishes the results.
     *
     * @param node      node configuration (output list, database/CSV flags and settings)
     * @param context   running task context; receives the {@link SpiderOutput}
     * @param variables task variables; each evaluated output is stored under its name
     */
    @Override
    public void execute(SpiderNode node, SpiderContext context, Map<String, Object> variables) {
        logger.debug("{}执行了", this.getClass().getSimpleName());
        SpiderOutput output = new SpiderOutput();
        output.setNodeName(node.getNodeName());
        output.setNodeId(node.getNodeId());
        boolean databaseFlag = "true".equals(node.getStringByKey(OUTPUT_DATABASE));
        boolean csvFlag = "true".equals(node.getStringByKey(OUTPUT_CSV));
        // Each KeyValue is one output item: key = output name, value = expression to evaluate.
        List<KeyValue> outputList = node.getDictListByKey(OUTPUT);
        if (outputList == null || outputList.isEmpty()) {
            logger.error("输出列表为空,如果不需要设置值，请删掉该节点，否则请至少配置一个输出项");
            return;
        }
        // Evaluated (name, text) pairs are only collected when they will be persisted.
        List<KeyValue> outputFields = null;
        if (databaseFlag || csvFlag) {
            outputFields = new ArrayList<>(outputList.size());
        }
        for (KeyValue kv : outputList) {
            String outputName = kv.getKey();
            String expression = kv.getValue();
            Object value = null;
            try {
                value = ExpressionUtils.execute(expression, variables);
                variables.put(outputName, value);
                // String.valueOf: a null result is legitimate and must not be reported as a failure.
                logger.info("输出{}={} ,expression={}", outputName, String.valueOf(value), expression);
            } catch (Exception e) {
                // Trailing Throwable argument (not consumed by a placeholder) makes SLF4J log the stack trace.
                logger.error("输出{}出错,异常信息:{}", outputName, e.getMessage(), e);
            }
            String text = value != null ? value.toString() : "";
            output.addOutput(outputName, text);
            if (outputFields != null) {
                outputFields.add(new KeyValue(outputName, text));
            }
        }
        if (databaseFlag) {
            String dsId = node.getStringByKey(DATASOURCE_ID);
            String tableName = node.getStringByKey(TABLE_NAME);
            if (StringUtils.isBlank(dsId)) {
                logger.warn("数据源ID为空！");
            } else if (StringUtils.isBlank(tableName)) {
                logger.warn("表名为空！");
            } else {
                outputDB(dsId, tableName, outputFields);
            }
        }
        if (csvFlag) {
            String csvName = node.getStringByKey(CSV_NAME);
            myOutputCsv(node, context, csvName, outputFields);
        }
        // Publish to the context (consumed by the WebSocket output).
        context.addOutput(output);
    }

    /**
     * Inserts one row built from {@code outputFields} into {@code tableName}.
     *
     * <p>Column values are bound as {@code ?} parameters. NOTE(review): the table and column
     * names come from the node configuration and are placed into the SQL text verbatim
     * (identifiers cannot be bound), so they must be trusted configuration.
     *
     * @throws RuntimeException wrapping any failure raised by the JDBC update
     */
    private void outputDB(String databaseId, String tableName, List<KeyValue> outputFields) {
        if (outputFields == null || outputFields.isEmpty()) {
            return;
        }
        JdbcTemplate template = new JdbcTemplate(DataSourceUtils.getDataSource(databaseId));
        SQL sql = new SQL();
        sql.INSERT_INTO(tableName);
        Object[] params = new Object[outputFields.size()];
        int index = 0;
        for (KeyValue kv : outputFields) {
            sql.VALUES(kv.getKey(), "?");
            params[index++] = kv.getValue();
        }
        try {
            template.update(sql.toString(), params);
        } catch (Exception e) {
            logger.error("执行sql出错,异常信息:{}", e.getMessage(), e);
            ExceptionUtils.wrapAndThrow(e);
        }
    }

    /**
     * Appends one record to the CSV file associated with this (context, node) pair, creating
     * the file and its printer on first use. The header row is taken from the field names of
     * the first call.
     *
     * @throws RuntimeException wrapping any {@link IOException} from file creation or writing
     */
    private void myOutputCsv(SpiderNode node, SpiderContext context, String csvName, List<KeyValue> outputFields) {
        if (outputFields == null || outputFields.isEmpty()) {
            return;
        }
        String printerKey = context.getId() + "-" + node.getNodeId();
        String[] headers = new String[outputFields.size()];
        List<String> records = new ArrayList<>(outputFields.size());
        for (int i = 0; i < headers.length; i++) {
            KeyValue kv = outputFields.get(i);
            headers[i] = kv.getKey();
            records.add(kv.getValue());
        }
        try {
            CSVPrinter printer = cachePrinter.get(printerKey);
            if (printer == null) {
                synchronized (cachePrinter) {
                    // Re-check: another thread may have created the printer while we waited.
                    printer = cachePrinter.get(printerKey);
                    if (printer == null) {
                        printer = openPrinter(node, csvName, headers);
                        cachePrinter.put(printerKey, printer);
                    }
                }
            }
            // Serialize writes so records for the same file do not interleave.
            synchronized (printer) {
                printer.printRecord(records);
            }
        } catch (IOException e) {
            logger.error("文件输出错误,异常信息:{}", e.getMessage(), e);
            ExceptionUtils.wrapAndThrow(e);
        }
    }

    /**
     * Creates the target CSV file under {@link #fileDir} (via {@code FileUtils.buildFileName})
     * and returns a printer whose format writes {@code headers} as the header row.
     * The stream is closed again if the writer or printer cannot be constructed.
     */
    private CSVPrinter openPrinter(SpiderNode node, String csvName, String[] headers) throws IOException {
        CSVFormat csvFormat = CSVFormat.DEFAULT.builder()
                .setHeader(headers)
                .setSkipHeaderRecord(false) // write the header row
                .build();
        File parentDir = new File(fileDir);
        if (!parentDir.exists()) {
            parentDir.mkdirs();
        }
        File file = FileUtils.buildFileName(parentDir.getPath(), csvName, ".csv", 0);
        logger.info("写入文件:{}", file.getAbsolutePath());
        String csvEncoding = node.getStringByKey(CSV_ENCODING, "utf-8");
        FileOutputStream os = new FileOutputStream(file);
        try {
            return new CSVPrinter(new OutputStreamWriter(os, csvEncoding), csvFormat);
        } catch (IOException | RuntimeException e) {
            // e.g. unsupported encoding: do not leak the open file handle.
            try {
                os.close();
            } catch (IOException suppressed) {
                e.addSuppressed(suppressed);
            }
            throw e;
        }
    }

    @Override
    public String supportShape() {
        return "output";
    }

    /** No setup is needed before a task. */
    @Override
    public void beforeListener() {

    }

    /**
     * Flushes and closes the CSV printers created for the context returned by
     * {@link SpiderContextHolder#get()}.
     */
    @Override
    public void afterListener() {
        SpiderContext ctx = SpiderContextHolder.get();
        if (ctx == null) {
            // No current context: there is no key prefix to release by.
            return;
        }
        // Release only this context's CSV printers.
        this.releasePrinters(ctx.getId());
    }

    /**
     * Removes, flushes and closes every cached printer whose key belongs to {@code contextId}.
     * Entries are removed before closing, so a failing flush/close cannot leave a stale entry.
     */
    private void releasePrinters(String contextId) {
        // Exact "<contextId>-" prefix (as built in myOutputCsv); contains() could also match
        // keys of other contexts whose ids merely contain this one.
        String prefix = contextId + "-";
        Iterator<Map.Entry<String, CSVPrinter>> iterator = this.cachePrinter.entrySet().iterator();
        while (iterator.hasNext()) {
            Map.Entry<String, CSVPrinter> entry = iterator.next();
            if (!entry.getKey().startsWith(prefix)) {
                continue;
            }
            // Remove through the iterator, never via the map while iterating.
            iterator.remove();
            CSVPrinter printer = entry.getValue();
            synchronized (printer) {
                try {
                    try {
                        printer.flush();
                    } finally {
                        // Always release the file handle, even if flushing failed.
                        printer.close();
                    }
                } catch (IOException e) {
                    logger.error("文件输出错误,异常信息:{}", e.getMessage(), e);
                }
            }
        }
    }
}
